package LiuYun

import spinal.core._
import spinal.lib._

/**
  * Static parameters of the instruction cache, supplied to [[CacheConfig]].
  *
  * NOTE(review): the names suggest index width / position width / way count;
  * confirm the exact meaning against CacheConfig, which is not visible here.
  */
object ICacheConfig extends CacheConfig {
    // Feature switches: both are turned off for the instruction cache.
    val withDirty: Boolean   = false
    val withBitMask: Boolean = false

    // Sizing parameters.
    val n_way: Int = 2
    val w_pos: Int = 2
    val w_idx: Int = 8
}
/**
  * Miss/refill port of the instruction cache.
  *
  * From the master's point of view the request side (`line`, `arvalid`,
  * `araddr`) and `rready` are outputs, while `arready` and the returned data
  * (`rvalid`, `rlast`, `rdata`) are inputs.
  *
  * NOTE(review): the signal names look like an AXI-style read channel pair;
  * confirm against the bus bridge that consumes this bundle.
  */
class IMiss extends Bundle with IMasterSlave {
    // Field order is kept as declared: it defines the Bundle's element order.
    val line:    Bool = Bool()
    val arvalid: Bool = Bool()
    val arready: Bool = Bool()
    val araddr:  UInt = LISA.PA
    val rvalid:  Bool = Bool()
    val rready:  Bool = Bool()
    val rlast:   Bool = Bool()
    val rdata:   UInt = UInt(32 bits)

    /** Sets the port directions as seen by the master side. */
    override def asMaster(): Unit = {
        // Request side.
        out(line, arvalid, araddr)
        in(arready)
        // Response side.
        out(rready)
        in(rvalid, rlast, rdata)
    }
}
/** Instruction-fetch request port; adds nothing on top of [[CacheReq]]. */
class InstReq extends CacheReq {}
/** Instruction-fetch response port: a [[CacheResp]] with an instruction-flavoured accessor. */
class InstResp extends CacheResp {
    /** Alias for the inherited `data` field, read as the fetched instruction. */
    def inst: UInt = this.data
}
/**
  * Control path of the instruction cache: a small state machine plus the
  * handshake signals it drives.
  *
  * The plain `Bool()` members declared below are not driven inside the state
  * machine itself; the `collect(...)` overloads at the bottom of the class wire
  * most of them to the surrounding ports, and the remaining ones (`hit`,
  * `l2tlb_willresp`, `l2tlb_ready`) must be assigned by whoever instantiates
  * this Area.
  *
  * Every output gets a default value before the `switch`, and the state
  * branches override it. This relies on SpinalHDL's "last assignment wins"
  * rule, so the statement order in this class is significant.
  */
class ICacheCtrl extends Area{
    object IState extends SpinalEnum{
        // NOTE(review): sOperate is declared but never entered or matched in this class.
        val sIdle,sLookup,sWaitL2,sWaitArready,sMiss,sOperate = newElement()
    }
    import IState._
    // input fields
    val cancel     = Bool()  // driven from resp.cancel
    val req_valid  = Bool()  // driven from req.valid
    val miss_ready = Bool()  // driven from miss.arready
    val resp_ready = Bool()  // driven from resp.allow
    val recv_done  = Bool()  // driven from recv.done
    val fill_done  = Bool()  // driven from fill.done
    val cached     = Bool()  // driven from info.cached
    val ex         = Bool()  // driven from info.ex
    val hit        = Bool()  // no collect() drives this; the owner assigns it

    // Cache-operation (CacopBus) handshake.
    val cop_valid  = Bool()  // assigned in collect(CacopBus) but not read inside this class
    val cop_ready  = Bool()
    val req_st_0_tag = Bool()
    val req_hit_inv = Bool()
    // A cache operation starts in the cycle where it is both requested and accepted.
    val start_hit_inv = cop_ready && req_hit_inv
    val st_0_tag   = cop_ready && req_st_0_tag
    // One cycle after start_hit_inv.
    val do_hit_inv = RegNext(start_hit_inv)
    // While any cache operation is pending or running, new fetch requests are refused
    // (see acceptReq).
    val doing_cop  = do_hit_inv || req_hit_inv || req_st_0_tag

    val init = RamInit(8)
    val s  = Reg(IState()) init sIdle
    val addr_ok = Bool()     // request can be accepted this cycle
    val data_ok = Bool()     // lookup result is available this cycle
    val miss_valid = Bool()  // a miss request is being offered
    val send_miss  = miss_valid && miss_ready  // the miss request is accepted
    val quitable  = Bool()   // when False, `cancel` must not reset the state machine
    val read = Bool()        // pulses when a RAM read is started
    val l1tlb_resp = new AddrTransResp
    val l2resp_valid = Bool()
    //val l2resp_ex = Bool()
    //val l2resp_cached = Bool()
    // Use the `Bool()` factory like every other signal in this class.
    val l2tlb_willresp = Bool()
    val l2tlb_ready = Bool()

    // Set once the L2 TLB request has been accepted; cleared in sIdle and sLookup.
    val l2sent = RegInit(False)
    val wait_l2 = Bool()
    wait_l2 := False

    // Registered L2 TLB request; its next value is False unless overridden below.
    val l2req = Bool() setAsReg()
    l2req := False
    /** Opens the request port (unless initialising or busy with a cache op) and the cache-op port. */
    def acceptReq:Unit = {
        addr_ok  := !init.valid && !doing_cop
        cop_ready := !do_hit_inv
    }
    /** Offers a miss request unless the current fetch is being cancelled. */
    def sendMiss:Unit = {
        miss_valid := !cancel
    }
    /** Schedules an L2 TLB request and asks for a move to sWaitL2, unless cancelled. */
    def sendL2req():Unit = {
        l2req := !cancel //Reg
        wait_l2 := !cancel //Bool
    }
    // Default values; the state branches below override them where needed.
    addr_ok := False
    data_ok := False
    miss_valid := False
    quitable := True
    cop_ready := False
    switch(s){
        is(sIdle){
            l2sent := False
            acceptReq
        }
        is(sLookup){
            //when(l1tlb_resp.hit || ex || l2resp_valid){
            l2sent := False
            // Result available: cache hit on a cached access with a valid translation,
            // or an exception.
            when(hit && cached && (l1tlb_resp.hit || l2resp_valid) || ex){
                data_ok := True
                when(resp_ready){
                    acceptReq
                }
            }.elsewhen(!l1tlb_resp.hit && !l2resp_valid){
                // No translation from the L1 TLB and no L2 response yet: ask the L2 TLB.
                sendL2req()
                addr_ok := False
            }.otherwise{
                // Translation is known but the access missed or is uncached.
                sendMiss
            }
            
        }
        is(sWaitL2){
            when(l2tlb_ready) {
                // Request accepted: stop requesting and remember that it was sent.
                l2req := False
                l2sent := True
            }.elsewhen(!l2sent && !cancel){
                // Keep requesting until accepted.
                l2req := True
                //l2sent := l2tlb_ready
            }
        }
        is(sWaitArready){
            sendMiss
        }
        is(sMiss){
            // A refill in flight cannot be abandoned by `cancel`.
            quitable := False
            when(fill_done && recv_done){
                acceptReq
            }
        }
    }
    // Next-state selection; earlier branches take priority.
    //  1. cancel (only while quitable) or init.valid      -> sIdle
    //  2. addr_ok: new request -> sLookup (with read), otherwise fall back to sIdle
    //  3. sWaitL2 and l2tlb_willresp                      -> sLookup (with read)
    //  4. miss request accepted                           -> sMiss
    //  5. miss request offered but not accepted           -> sWaitArready
    //  6. start_hit_inv: only pulses `read`, state is unchanged
    //  7. wait_l2                                         -> sWaitL2
    read := False
    when(quitable && cancel || init.valid){
        s := sIdle
    }.elsewhen(addr_ok){
        when(req_valid){
            read := True
            s := sLookup
        }.elsewhen(s =/= sIdle){
            s := sIdle
        }
    }.elsewhen(s === sWaitL2 && l2tlb_willresp){
        s := sLookup
        read := True
    }.elsewhen(send_miss){
        s := sMiss
    }.elsewhen(miss_valid && !miss_ready){
        s := sWaitArready
    }.elsewhen(start_hit_inv){
        read := True
    }.elsewhen(wait_l2){
        s := sWaitL2
    }
    /** Connects the address handshake of the miss port. */
    def collect(miss:IMiss):Unit = {
        miss_ready := miss.arready
        miss.arvalid := miss_valid && !cancel
    }
    /** Connects the fetch request handshake. */
    def collect(req:InstReq):Unit = {
        req_valid := req.valid
        req.allow := addr_ok
    }
    /** Reads `cancel` and `allow` from the fetch response port. */
    def collect(resp:InstResp):Unit = {
        cancel := resp.cancel
        resp_ready := resp.allow
    }
    /** Starts the bus receiver when a miss request is accepted and reads its `done` flag. */
    def collect(recv:BusReceiver):Unit = {
        recv_done := recv.done
        recv.start := send_miss
    }
    /** Starts the cache filler on an accepted miss, only when `iscached` is set. */
    def collect(fill:CacheFiller, iscached:Bool):Unit = {
        fill_done := fill.done
        fill.start := send_miss && iscached
    }
    /** Reads the exception and cacheability flags of the current lookup. */
    def collect(info:ILookupInfo):Unit = {
        ex     := info.ex//l1tlb_resp.ex || l2resp_ex
        cached := info.cached
    }
    /** Connects the cache-operation bus and decodes which instruction-cache operation is requested. */
    def collect(icop:CacopBus):Unit = {
        icop.ready := cop_ready
        cop_valid := icop.valid
        req_st_0_tag := icop.valid && (icop.isIdxInv || icop.isSt0Tag) && icop.isICache
        req_hit_inv  := icop.valid && icop.isHitInv && icop.isICache
    }
    /** Reads the valid flag of the L2 TLB response. */
    def collect(l2srchresp:Flow[L2TLBResp]):Unit = {
        l2resp_valid := l2srchresp.valid
        //l2resp_ex := l2srchresp.ex
        //l2tlb_willresp := l2srchresp.will_resp
    }
    /** Copies the same-named fields of the L1 TLB response into `l1tlb_resp`. */
    def collect(l1tlbresp:AddrTransResp):Unit = {
        l1tlb_resp.assignSomeByName(l1tlbresp)
    }
}
/**
  * Per-lookup information for the instruction cache: translation result
  * (`pt`, `mat`), exception status (`ex`, `ecode`) and the registered virtual
  * index `vidx`.
  */
class ILookupInfo extends LookupInfo {
    val cfg = ICacheConfig

    val pt:    UInt = UInt(LISA.w_ppn bits)
    val mat:   UInt = UInt(2 bits)
    val ex:    Bool = Bool()
    val ecode: UInt = LISA.Ex.Code
    val vidx:  UInt = Reg(UInt(12 bits))

    /** Elaboration-time switch: when it returns true, every access is treated as cached. */
    def forceCached: Boolean = false

    // An access is cached when forced, or when `mat` equals 1.
    val cached: Bool = {
        if (forceCached) True
        else mat === U(1)
    }

    /**
      * Drives the address and `line` flag of a miss request.
      *
      * @param miss   miss port whose `araddr` and `line` are assigned
      * @param is_new selects the live lookup values (`pa`, `cached`) when set,
      *               otherwise the saved ones (`res.pt @@ vidx`, `res.cc`)
      * @param res    saved TLB result used when `is_new` is low
      */
    def setFields(miss: IMiss, is_new: Bool, res: SavedTlbres): Unit = {
        miss.line := Mux(
            is_new,
            cached,
            res.cc
        )
        miss.araddr := Mux(
            is_new,
            pa,
            res.pt @@ vidx
        )
    }
}

/**
  * Instruction cache top level.
  *
  * Instantiates the tag/valid and data RAMs from [[ICacheConfig]], the control
  * state machine ([[ICacheCtrl]]), the replacement logic, the bus receiver and
  * the cache filler, and wires them to the fetch, TLB, miss and cache-operation
  * ports.
  *
  * Several signals are assigned more than once (for example `info.pt` and
  * `info.vidx`); the later assignment wins, so the statement order in this
  * class is significant.
  */
class ICache extends Component{
    val io = new Bundle{
        val ireq  :InstReq  = slave (new InstReq )
        val iresp :InstResp = master(new InstResp)
        val iexcpt:IExcept  = out   (new IExcept )
        val isrch  = master(Stream(new AddrTransReq(LISA.TLB.FETCH)))
        val isrchresp = in(new AddrTransResp)
        val l2srch = master(Stream(new L2TLBSrchFet))
        val l2srchresp = slave(Flow(new L2TLBResp))
        val l2will_resp = in(Bool())
        val imiss :IMiss    = master(new IMiss   )
        val icop  :CacopBus = slave (new CacopBus)
    }

    io.isrch.valid := True // tied high; valid is not actually used
    io.isrch.va := io.ireq.va
    val cfg  = ICacheConfig
    val tagv = cfg.getTagV
    val data = cfg.getIDat
    val ctrl = new ICacheCtrl
    val info = new ILookupInfo
    val repl = new LRUReplace
    val recv = new BusReceiver(io.iresp,io.imiss)
    val hits = Bits()
    // High in the cycle after ctrl.read.
    val read_fin = RegNext(ctrl.read)
    val savedtlb = new SavedTlbres
    // Per-way hit vector. Priority: L2 TLB response, then the L1 TLB response in the
    // cycle after a read, otherwise a compare against the saved translation.
    when(io.l2srchresp.valid){
        hits := io.l2srchresp.tagCompare(tagv)
    }.elsewhen(read_fin){
        hits := io.isrchresp.tagCompare(tagv)
    }.otherwise{
        hits := Vec(tagv.map(_.io.rd === U("1") @@ savedtlb.pt)).asBits
    }
    // A hit only counts when a translation is available (L1 TLB hit or L2 response).
    val hit = Mux(io.isrchresp.hit || io.l2srchresp.valid, hits.orR, False)
    val rdat = cfg.getVec2D(data(_)(_).io.rd)
    val inst = info.select(hits,rdat)
    val fill = new CacheFiller(io.imiss)
    info.va := io.ireq.va
    // Default: same-named info fields come from the saved TLB result; the blocks
    // below override some of them.
    info.assignSomeByName(savedtlb)
    //info.ex := False
    //info.ecode := U(0)
    //info.pt := U(0)
    //info.mat := U(0)
    //info.vidx := U(0)
    // Virtual address captured whenever ireq.valid is high; used for the L2 TLB request.
    val l2req_va = RegNextWhen(io.ireq.va,io.ireq.valid)
    // Use the `Bool()` factory, consistent with the rest of the file.
    val cc = Bool()
    when(ctrl.read && !io.l2will_resp){ //should not update for l2 lookup
        info.vidx := info.nextVIdx
    }
    // In the cycle after a read, take the translation result from the L2 TLB
    // response when it is valid, otherwise from the L1 TLB response.
    when(read_fin){
        //info.assignSomeByName(io.isrch.res)
        //info.vidx := info.nextVIdx
        when(io.l2srchresp.valid){
            //info.assignSomeByName(io.l2srchresp.payload)
            info.pt := io.l2srchresp.ptag
            info.ecode := io.l2srchresp.ecode
            info.ex := io.l2srchresp.ex
            info.mat := io.l2srchresp.mat.asUInt
            //info.vidx := info.nextVIdx
        }.otherwise{//.elsewhen(io.isrchresp.hit || io.isrchresp.ex){
            info.assignSomeByName(io.isrchresp)
            //info.vidx := info.nextVIdx
        }
    }
    // L2 TLB request: valid only while ctrl.l2req is set and the fetch is not cancelled.
    io.l2srch.valid := False
    io.l2srch.vppn := l2req_va.takeHigh(LISA.TLB.VPPN_LEN+1).asUInt
    when(ctrl.l2req){
        io.l2srch.valid := !io.iresp.cancel
        //io.l2srch.vppn := l2req_va.takeHigh(LISA.TLB.VPPN_LEN+1).asUInt
    }
    // Wire the controller inputs.
    ctrl.l2tlb_willresp := io.l2will_resp
    //ctrl.l2resp_cached := io.l2srchresp.payload.mat === B(1)
    ctrl.collect(io.ireq)
    ctrl.collect(io.iresp)
    ctrl.collect(io.imiss)
    ctrl.collect(info)
    ctrl.collect(recv)
    ctrl.collect(fill,cc)
    ctrl.collect(io.icop)
    ctrl.collect(io.isrchresp)
    ctrl.collect(io.l2srchresp)
    ctrl.hit := hit
    ctrl.l2tlb_ready := io.l2srch.ready
    
    // NOTE(review): l1resp_valid is only referenced by the commented-out line below.
    val l1resp_valid = RegNext(io.isrch.valid)
    //val update_saved = l1resp_valid && io.isrchresp.hit || io.l2srchresp.valid
    // Save the translation result in the cycle after a read (L2 response has priority).
    when(read_fin){
        savedtlb.pt := Mux(io.l2srchresp.valid, io.l2srchresp.ptag, io.isrchresp.pt)
        savedtlb.mat := Mux(io.l2srchresp.valid, io.l2srchresp.mat.asUInt, io.isrchresp.mat)
        savedtlb.ex := Mux(io.l2srchresp.valid, io.l2srchresp.ex, io.isrchresp.ex)
        savedtlb.ecode := Mux(io.l2srchresp.valid, io.l2srchresp.ecode, io.isrchresp.ecode)
        //savedtlb.esubc := Mux(io.l2srchresp.valid, io.l2srchresp.esubc, io.isrchresp.esubc)
    }
    // Cacheability handed to the filler: live value right after a read, saved value otherwise.
    cc := Mux(read_fin, info.cached, savedtlb.cc)
    info.setFields(io.imiss, read_fin, savedtlb)
    //info.setFields(io.imiss)
    io.iexcpt.ex    := ctrl.data_ok && info.ex
    info.send_ex(io.iexcpt)
    // The response comes either from the bus receiver or from the cache lookup;
    // the receiver's data is selected whenever recv.valid is high.
    io.iresp.valid  := recv.valid || ctrl.data_ok
    io.iresp.data   := Mux(recv.valid,recv.data,inst)


    // RAM address shared by all tag and data RAMs.
    val addr = UInt(8 bits)
    // A hit-invalidate operation overrides the lookup index and tag with the icop values.
    when(ctrl.start_hit_inv){
        info.vidx := io.icop.vidx
        info.pt   := io.icop.ptag
    }
    // Address priority: init index, then info.widx (fill, hit-invalidate or pending
    // L2 response), then the cache-op index, then info.ridx.
    when(ctrl.init.valid){
        addr := ctrl.init.idx
    }.elsewhen(fill.valid || ctrl.do_hit_inv || ctrl.l2tlb_willresp){
        addr := info.widx
    }.elsewhen(ctrl.req_st_0_tag || ctrl.req_hit_inv){
        addr := io.icop.vidx(11 downto 4)
    }.otherwise{
        addr := info.ridx
    }
    val st_0_mask = cfg.getWayMask(io.icop.vidx(0 downto 0))
    val fill_tagv = fill.isWriteTagV(io.imiss)
    val fill_data = fill.isWriteData(io.imiss)
    val clear_tag = ctrl.st_0_tag | ctrl.do_hit_inv
    // Tag write data: the saved tag on a fill, zero in every other case.
    val tagv_wd:UInt = Mux(fill_tagv && !clear_tag,savedtlb.tagv,U(0))
    val data_wd:UInt = io.imiss.rdata
    val repl_mask = Range(0,cfg.n_way ).map(repl.way === U(_))
    val fill_wpos  = fill.getPos(info)
    val fill_wmsk = Range(0,cfg.n_bank).map(fill_wpos === U(_))
    val bank_rmsk = Range(0,cfg.n_bank).map(info.rpos === U(_))
    cfg.each_way(tagv(_)   .io.wd := tagv_wd)
    cfg.each_way(tagv(_)   .io.a  := addr   )
    cfg.each_blk(data(_)(_).io.wd := data_wd)
    cfg.each_blk(data(_)(_).io.a  := addr   )
    // Tag RAM command, in priority order: init write to all ways, store-tag/index
    // invalidate on the selected way, hit invalidate on the hitting way(s), fill write
    // to the replacement way, read, idle.
    when(ctrl.init.valid){
        cfg.each_way(tagv(_).io.setWrite())
    }.elsewhen(ctrl.st_0_tag){
        cfg.each_way((i:Int)=>tagv(i).io.setWriteWhen(st_0_mask(i)))
    }.elsewhen(ctrl.do_hit_inv){
        cfg.each_way((i:Int)=>tagv(i).io.setWriteWhen(hits(i)))
    }.elsewhen(fill_tagv){
        cfg.each_way((i:Int)=>tagv(i).io.setWriteWhen(repl_mask(i)))
    }.elsewhen(ctrl.read){
        cfg.each_way((i:Int)=>tagv(i).io.setRead())
    }.otherwise{
        cfg.each_way(tagv(_).io.setIdle())
    }
    // Data RAM command: fill write to the replacement way and selected bank,
    // else read the selected bank, else idle.
    when(fill_data){
        cfg.each_blk((i:Int,j:Int)=>
            data(i)(j).io.setWriteWhen(repl_mask(i) && fill_wmsk(j))
        )
    }.elsewhen(ctrl.read){
        cfg.each_blk((i:Int,j:Int)=>
            data(i)(j).io.setReadWhen(bank_rmsk(j))
        )
    }.otherwise{
        cfg.each_blk((i:Int,j:Int)=>
            data(i)(j).io.setIdle()
        )
    }

    // Replacement state: read alongside every lookup, cleared during init, and
    // updated whenever a response is handed over.
    when(ctrl.read){
        repl.read(info.ridx)
    }
    when(ctrl.init.valid){
        repl.write(ctrl.init.idx,False)
    }.elsewhen(io.iresp.valid && io.iresp.allow){
        repl.write(info.widx,repl.update(fill.valid,hits))
    }

/*==========================Performance Counters================================*/
    // Accepted fetch requests.
    val pm_icache_req = UInt(32 bits) setAsReg() init(0)
    when(io.ireq.valid && io.ireq.allow){
        pm_icache_req := pm_icache_req + U(1)
    }

    // NOTE(review): increments on every cycle spent in sLookup with an uncached
    // access, so it may count cycles rather than requests if the state persists.
    val pm_icache_uncache_req = UInt(32 bits) setAsReg() init(0)
    when(ctrl.s === ctrl.IState.sLookup && !ctrl.cached){
        pm_icache_uncache_req := pm_icache_uncache_req + U(1)
    }

    // Cached lookups that raise a miss request.
    val pm_icache_miss = UInt(32 bits) setAsReg() init(0)
    when(ctrl.s === ctrl.IState.sLookup && ctrl.cached && ctrl.miss_valid){
        pm_icache_miss := pm_icache_miss + U(1)
    }

        
/*==========================Performance Counters================================*/
}
